Skip to content

Add streaming inference API#46

Open
masquare wants to merge 4 commits into
mainfrom
feature/stream-inference
Open

Add streaming inference API#46
masquare wants to merge 4 commits into
mainfrom
feature/stream-inference

Conversation

@masquare

Copy link
Copy Markdown
Collaborator

♻️ Current situation & Problem

The current inference API (generate() / eval_prompt()) returns results only after the entire sequence has been generated. For interactive applications and real-time monitoring dashboards, users must wait for the full generation to complete before seeing any output. This prevents use cases that benefit from incremental token delivery, such as live clinical decision support or streaming analytics pipelines.

⚙️ Release Notes

  • Add stream_generate() and stream_prompt() methods to TimeSeriesLLM, OpenTSLMSP, and OpenTSLMFlamingo for token-by-token streaming inference via Python iterators.
  • Shared streaming helpers (_validate_streaming_batch, _iterate_streamer) in the TimeSeriesLLM base class handle threading and error propagation.
  • Fix resize_token_embeddings to pass mean_resizing=False, preventing failures when the model is initialized on meta tensors in low-memory environments.
from opentslm.model.llm.OpenTSLMSP import OpenTSLMSP

model = OpenTSLMSP(llm_id="google/gemma-3-1b-it", device="cpu")

for token in model.stream_prompt(prompt, max_new_tokens=200):
    print(token, end="", flush=True)

Breaking changes

  • OpenTSLMFlamingo.generate() and compute_loss() now use inputs_embeds instead of lang_x/vision_x kwargs internally. Public API is unchanged.

📚 Documentation

The new stream_generate() and stream_prompt() methods follow the same conventions as the existing generate() and eval_prompt() methods. In-line docstrings and type annotations are provided. The base class defines the interface and shared utilities; each subclass implements the architecture-specific generation logic.

✅ Testing

A comprehensive test suite is added in test/test_stream_inference.py (346 lines) covering:

  • Token chunk yielding and empty-string filtering for both OpenTSLMSP and OpenTSLMFlamingo
  • Error propagation from the generation thread to the caller
  • Batch-size validation (streaming is single-sample only)
  • stream_prompt end-to-end flow (prompt conversion, normalization, eval mode)
  • Conditioned layer cleanup on success and failure (Flamingo)
  • mean_resizing=False is passed during init for both architectures

Code of Conduct & Contributing Guidelines

By creating and submitting this pull request, you agree to follow our Code of Conduct and Contributing Guidelines:

@masquare masquare requested review from RealLast and ThomasKaar and removed request for RealLast April 14, 2026 06:30
@coderabbitai

coderabbitai Bot commented Apr 14, 2026

Copy link
Copy Markdown

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: d075935d-cec7-428e-bc11-53ef073f35de

📥 Commits

Reviewing files that changed from the base of the PR and between df5110d and 2cf872f.

📒 Files selected for processing (1)
  • src/opentslm/model/llm/OpenTSLMFlamingo.py

📝 Walkthrough

Walkthrough

Added streaming generation interfaces and implementations to TimeSeriesLLM, OpenTSLMSP, and OpenTSLMFlamingo; refactored Flamingo to build and use token embeddings with explicit vision encoding and media-conditioning; adjusted tokenizer embedding resize behavior; added unit tests and streaming utilities.

Changes

Cohort / File(s) Summary
Streaming Core & Utilities
src/opentslm/model/llm/TimeSeriesLLM.py
Added abstract stream_generate/stream_prompt signatures, _validate_streaming_batch to enforce single-sample streams, and _iterate_streamer which runs generation in a daemon thread and yields streamer tokens with cleanup and error propagation.
SP Model: streaming
src/opentslm/model/llm/OpenTSLMSP.py
Added stream_generate and stream_prompt using TextIteratorStreamer and _iterate_streamer; enforce batch validation; create streamer with skip_prompt=True; changed resize_token_embeddings(..., mean_resizing=False) and removed unused LoRA param construction; expanded typing/imports.
Flamingo Model: streaming & embed refactor
src/opentslm/model/llm/OpenTSLMFlamingo.py
Refactored generation/loss to build inputs_embeds via _build_input_embeddings and _forward_with_embeddings; added _condition_media_locations; explicit vision encoding _encode_vision_x(images) before generation; new stream_generate and stream_prompt using TextIteratorStreamer; ensure conditioned decoder layers cleared in finally; tokenizer resize uses mean_resizing=False; adjusted checkpoint warning prints.
Tests: streaming behavior
test/test_stream_inference.py
New tests with FakeStreamer and monkeypatches verifying streaming output, exception propagation, batch-size validation, prompt handling/aggregation, model eval state, conditioned-layer cleanup, and tokenizer mean_resizing=False calls for both models.
Dependency constraint
pyproject.toml
Bumped transformers minimum from >=4.25 to >=4.46.

Sequence Diagram(s)

sequenceDiagram
    actor Client
    participant Model as LLM Wrapper
    participant Encoder as Vision/LLM Encoder
    participant Streamer as TextIteratorStreamer
    participant Thread as WorkerThread

    Client->>Model: stream_generate(batch, max_new_tokens, ...)
    Model->>Model: _validate_streaming_batch(batch)
    Model->>Model: build inputs_embeds & attention_mask
    alt Flamingo multimodal path
        Model->>Encoder: _encode_vision_x(images)
        Model->>Model: _condition_media_locations(input_ids)
    end
    Model->>Streamer: instantiate TextIteratorStreamer(skip_prompt=True)
    Model->>Thread: start daemon running generate_fn
    Thread->>Encoder: call model.generate(inputs_embeds=..., streamer=Streamer)
    Encoder-->>Streamer: push token chunks
    loop stream tokens
        Streamer-->>Model: yield text chunk
        Model-->>Client: yield text chunk
    end
    Model->>Thread: join thread (finally)
    Model->>Model: clear conditioned layers (if any)
    Model-->>Client: finish
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 12.50% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title 'Add streaming inference API' accurately and concisely summarizes the main change: adding streaming methods to the inference API across multiple model classes.
Description check ✅ Passed The description is directly related to the changeset, providing problem context, release notes with concrete examples, breaking changes disclosure, documentation approach, and a comprehensive testing summary.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feature/stream-inference

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@masquare masquare requested a review from RealLast April 14, 2026 06:31

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 6

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/opentslm/model/llm/OpenTSLMFlamingo.py`:
- Around line 296-308: The sync generation path may leak conditioned media state
because self.model.lang_encoder.clear_conditioned_layers() is only called after
a successful lang_encoder.generate; wrap the conditioned cleanup in a finally
block so clear_conditioned_layers() always runs even if generate() raises.
Specifically, in _forward_with_embeddings()/the sync branch where you call
self.model._encode_vision_x(...), self._condition_media_locations(input_ids) and
then gen_ids = self.model.lang_encoder.generate(...), ensure you call
self.model.lang_encoder.clear_conditioned_layers() inside a finally (or
try/finally) surrounding the generate() call so conditioned layers are cleared
on success or exception. Use the same cleanup pattern as stream_generate() to
avoid leaking state between requests.
- Around line 182-191: The loop in OpenTSLMFlamingo._condition_media_locations
currently accesses decoder layers via
self.model.lang_encoder.get_decoder().layers which hard-codes a Llama-style
accessor; change it to use the model-specific accessor
self.model.lang_encoder._get_decoder_layers() so the dynamic
decoder_layers_attr_name set by lang_encoder.set_decoder_layers_attr_name() is
respected (see the pattern used in
TimeSeriesFlamingoWithTrainableEncoder._get_decoder_layers()); update the loop
to iterate over _get_decoder_layers() and leave the rest of the method
(media_locations/attend_previous and per-layer calls to
condition_media_locations and condition_attend_previous) unchanged.

In `@src/opentslm/model/llm/OpenTSLMSP.py`:
- Around line 7-8: The code uses TextIteratorStreamer and the mean_resizing
argument on resize_token_embeddings which are not available in transformers
4.25.0; update the project dependency to a transformers version that includes
these features (at least >=4.28.0 for TextIteratorStreamer and a more recent
release that provides mean_resizing—use a pinned range like
"transformers>=4.28.0,<5" or the specific version you have validated), then run
tests; update the requirements/pyproject entry and any lockfile, and re-run CI;
verify imports of TextIteratorStreamer and calls to
AutoModelForCausalLM.resize_token_embeddings(...) in the OpenTSLMSP initializer
and the resize_token_embeddings usage in OpenTSLMFlamingo (lines referenced in
the review) work without errors.

In `@src/opentslm/model/llm/TimeSeriesLLM.py`:
- Around line 49-54: The abstract TimeSeriesLLM.stream_prompt signature needs to
accept generation kwargs so callers typed against TimeSeriesLLM won't get
unexpected keyword errors; update the method signature on
TimeSeriesLLM.stream_prompt to include a generate_kwargs parameter (e.g.,
generate_kwargs: Optional[Dict[str, Any]] = None) and adjust the
NotImplementedError message if desired; ensure the symbol name stream_prompt on
class TimeSeriesLLM matches the subclasses OpenTSLMSP and OpenTSLMFlamingo so
their existing parameters like temperature or stream_timeout are supported when
typed against the base class.
- Around line 64-88: The unconditional thread.join() blocks cancel/timeout
paths; change _iterate_streamer so it does not perform a blocking join: remove
the unconditional thread.join() and either skip joining entirely (relying on
daemon=True) or use a non-blocking join (thread.join(0) or
thread.join(timeout=0)) so the caller can return immediately on
timeout/early-consume; keep the error check (if error is not None: raise error)
after the non-blocking join behavior so exceptions from runner are still
propagated without waiting for a stuck generation thread.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: d2888ce6-9d11-47c4-834b-2fd361f07c2d

📥 Commits

Reviewing files that changed from the base of the PR and between 104013b and b6899f6.

📒 Files selected for processing (4)
  • src/opentslm/model/llm/OpenTSLMFlamingo.py
  • src/opentslm/model/llm/OpenTSLMSP.py
  • src/opentslm/model/llm/TimeSeriesLLM.py
  • test/test_stream_inference.py

Comment thread src/opentslm/model/llm/OpenTSLMFlamingo.py
Comment thread src/opentslm/model/llm/OpenTSLMFlamingo.py Outdated
Comment thread src/opentslm/model/llm/OpenTSLMFlamingo.py
Comment thread src/opentslm/model/llm/OpenTSLMSP.py
Comment thread src/opentslm/model/llm/TimeSeriesLLM.py
Comment thread src/opentslm/model/llm/TimeSeriesLLM.py
- Use _get_decoder_layers() instead of hard-coded get_decoder().layers
  for multi-architecture support in _condition_media_locations
- Wrap clear_conditioned_layers() in finally block in generate() to
  prevent state leaks on errors
- Add **generate_kwargs to base TimeSeriesLLM.stream_prompt() signature
  to match subclass interfaces
- Bump transformers minimum version from >=4.25 to >=4.46 for
  TextIteratorStreamer and mean_resizing support

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/opentslm/model/llm/OpenTSLMFlamingo.py`:
- Around line 111-121: The constructor currently unconditionally calls
_infer_decoder_layers_attr_name and overwrites any explicit
decoder_layers_attr_name; change the logic to honor a caller-provided override
by only calling _infer_decoder_layers_attr_name when decoder_layers_attr_name is
not provided (e.g., is None/empty), and otherwise use the provided
decoder_layers_attr_name when calling lang_encoder.set_decoder_layers_attr_name;
keep the existing resize_token_embeddings call as-is.
- Around line 170-180: The current sequence calls
self.model._encode_vision_x(...) and self._condition_media_locations(...) which
mutate per-layer conditioned state but only clears that state in a finally after
invoking FlamingoLMMixin/lang_encoder.forward; if either setup step raises the
clear is skipped and state leaks. Wrap the setup and forward so that
clear_conditioned_layers() is registered to run before any mutation can occur:
call clear_conditioned_layers() in a finally that covers _encode_vision_x and
_condition_media_locations (i.e., start the try/finally before those calls),
then perform inputs_embeds building and the call to super(...).forward inside
the try, and always call self.model.lang_encoder.clear_conditioned_layers() in
the finally. Apply the same pattern to the other occurrences referenced (around
the blocks at ~297-310 and ~342-355).
- Around line 329-357: The torch.inference_mode() context is currently outside
run_generation and therefore not active in worker threads; move the
torch.inference_mode() call inside the run_generation function so that
_encode_vision_x, _condition_media_locations and model.lang_encoder.generate
execute under inference mode. Specifically, wrap the body of run_generation (the
calls to self.model._encode_vision_x, self._condition_media_locations, and
self.model.lang_encoder.generate, and the finally block that calls
self.model.lang_encoder.clear_conditioned_layers) in a with
torch.inference_mode(): block so that streaming generation invoked via
_iterate_streamer and the TextIteratorStreamer runs without creating autograd
state.
- Around line 301-313: The generation path currently passes inputs_embeds alone
to self.model.lang_encoder.generate which (on transformers >=4.46) returns only
new tokens, so the existing slice answer_only_ids = gen_ids[:,
input_ids.shape[1]:] and TextIteratorStreamer(skip_prompt=True) drop real
generated tokens; fix by either (A) pass input_ids along with inputs_embeds into
self.model.lang_encoder.generate so gen_ids includes the prompt (update the
generate call at the inputs_embeds usage and keep answer_only_ids slicing and
TextIteratorStreamer(skip_prompt=True)), or (B) treat gen_ids as "new tokens
only" by removing the slice (set answer_only_ids = gen_ids) and change
TextIteratorStreamer to skip_prompt=False (update the code paths that create
answer_only_ids and the TextIteratorStreamer instantiation). Ensure changes
touch the generate call (self.model.lang_encoder.generate), answer_only_ids
assignment, and TextIteratorStreamer(skip_prompt=...) usages so behavior is
consistent.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 86a257c5-784f-4c72-b892-64919f231460

📥 Commits

Reviewing files that changed from the base of the PR and between b6899f6 and df5110d.

⛔ Files ignored due to path filters (1)
  • uv.lock is excluded by !**/*.lock
📒 Files selected for processing (3)
  • pyproject.toml
  • src/opentslm/model/llm/OpenTSLMFlamingo.py
  • src/opentslm/model/llm/TimeSeriesLLM.py
✅ Files skipped from review due to trivial changes (1)
  • pyproject.toml

Comment thread src/opentslm/model/llm/OpenTSLMFlamingo.py
Comment thread src/opentslm/model/llm/OpenTSLMFlamingo.py Outdated
Comment thread src/opentslm/model/llm/OpenTSLMFlamingo.py
Comment thread src/opentslm/model/llm/OpenTSLMFlamingo.py
- Move _encode_vision_x and _condition_media_locations inside try blocks
  so clear_conditioned_layers always runs even if setup steps raise
- Add torch.inference_mode() inside run_generation thread since
  inference_mode is thread-local and does not propagate to new threads
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant